{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W3D3_NetworkCausality/student/W3D3_Tutorial1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Neuromatch Academy 2020: Week 3 Day 5, Tutorial 1\n",
    "\n",
    "# Causality Day: Interventions\n",
    "\n",
    "**Content creators**: Ari Benjamin, Tony Liu, Konrad Kording\n",
    "\n",
    "**Content reviewers**: Mike X Cohen, Madineh Sarvestani, Ella Batty, Michael Waskom\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "#Tutorial Objectives\n",
    "\n",
    "We list our overall day objectives below, with the sections we will focus on in this notebook in bold:\n",
    "\n",
    "1.   **Master definitions of causality**\n",
    "2.   **Understand that estimating causality is possible**\n",
    "3.  **Learn 4 different methods and understand when they fail**\n",
    "    1. **Perturbations**\n",
    "    2. Correlations\n",
    "    3. Simultaneous fitting/regression\n",
    "    4. Instrumental variables\n",
    "\n",
    "### Tutorial setting\n",
    "\n",
    "How do we know if a relationship is causal? What does that mean? And how can we estimate causal relationships within neural data?\n",
    "\n",
    "The methods we'll learn today are very general and can be applied to all sorts of data, and in many circumstances.\n",
    "Causal questions are everywhere!\n",
    "\n",
    "\n",
    "### Tutorial 1 Objectives:\n",
    "\n",
    "1.   Simulate a neural system\n",
    "2.   Understand perturbation as a method of estimating causality"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "code",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:27.147322Z",
     "iopub.status.busy": "2021-05-25T01:22:27.146834Z",
     "iopub.status.idle": "2021-05-25T01:22:27.466469Z",
     "shell.execute_reply": "2021-05-25T01:22:27.465852Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:27.480090Z",
     "iopub.status.busy": "2021-05-25T01:22:27.478876Z",
     "iopub.status.idle": "2021-05-25T01:22:27.582529Z",
     "shell.execute_reply": "2021-05-25T01:22:27.581364Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Figure settings\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:27.601512Z",
     "iopub.status.busy": "2021-05-25T01:22:27.595740Z",
     "iopub.status.idle": "2021-05-25T01:22:27.610126Z",
     "shell.execute_reply": "2021-05-25T01:22:27.608802Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Helper functions\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    \"\"\"\n",
    "    Compute sigmoid nonlinearity element-wise on x\n",
    "\n",
    "    Args:\n",
    "        x (np.ndarray): the numpy data array we want to transform\n",
    "    Returns\n",
    "        (np.ndarray): x with sigmoid nonlinearity applied\n",
    "    \"\"\"\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "\n",
    "def create_connectivity(n_neurons, random_state=42):\n",
    "    \"\"\"\n",
    "    Generate our nxn causal connectivity matrix.\n",
    "\n",
    "    Args:\n",
    "        n_neurons (int): the number of neurons in our system.\n",
    "        random_state (int): random seed for reproducibility\n",
    "\n",
    "    Returns:\n",
    "        A (np.ndarray): our 0.1 sparse connectivity matrix\n",
    "    \"\"\"\n",
    "    np.random.seed(random_state)\n",
    "    A_0 = np.random.choice([0, 1], size=(n_neurons, n_neurons), p=[0.9, 0.1])\n",
    "\n",
    "    # set the timescale of the dynamical system to about 100 steps\n",
    "    _, s_vals, _ = np.linalg.svd(A_0)\n",
    "    A = A_0 / (1.01 * s_vals[0])\n",
    "\n",
    "    # _, s_val_test, _ = np.linalg.svd(A)\n",
    "    # assert s_val_test[0] < 1, \"largest singular value >= 1\"\n",
    "\n",
    "    return A\n",
    "\n",
    "\n",
    "def see_neurons(A, ax):\n",
    "    \"\"\"\n",
    "    Visualizes the connectivity matrix.\n",
    "\n",
    "    Args:\n",
    "        A (np.ndarray): the connectivity matrix of shape (n_neurons, n_neurons)\n",
    "        ax (plt.axis): the matplotlib axis to display on\n",
    "\n",
    "    Returns:\n",
    "        Nothing, but visualizes A.\n",
    "    \"\"\"\n",
    "    A = A.T  # make up for opposite connectivity\n",
    "    n = len(A)\n",
    "    ax.set_aspect('equal')\n",
    "    thetas = np.linspace(0, np.pi * 2, n, endpoint=False)\n",
    "    x, y = np.cos(thetas), np.sin(thetas),\n",
    "    ax.scatter(x, y, c='k', s=150)\n",
    "\n",
    "    # Renormalize\n",
    "    A = A / A.max()\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            if A[i, j] > 0:\n",
    "                ax.arrow(x[i], y[i], x[j] - x[i], y[j] - y[i], color='k', alpha=A[i,j],\n",
    "                         head_width=.15, width = A[i, j] / 25, shape='right',\n",
    "                         length_includes_head=True)\n",
    "    ax.axis('off')\n",
    "\n",
    "\n",
    "def get_perturbed_connectivity_all_neurons(perturbed_X):\n",
    "    \"\"\"\n",
    "    Estimates the connectivity matrix of perturbations through stacked correlations.\n",
    "\n",
    "    Args:\n",
    "        perturbed_X (np.ndarray): the simulated dynamical system X of shape\n",
    "                                  (n_neurons, timesteps)\n",
    "\n",
    "    Returns:\n",
    "        R (np.ndarray): the estimated connectivity matrix of shape\n",
    "                        (n_neurons, n_neurons)\n",
    "    \"\"\"\n",
    "    # select perturbations (P) and outcomes (Outs)\n",
    "    # we perturb the system every over time step, hence the 2 in slice notation\n",
    "    P = perturbed_X[:, ::2]\n",
    "    Outs = perturbed_X[:, 1::2]\n",
    "\n",
    "    # stack perturbations and outcomes into a 2n by (timesteps / 2) matrix\n",
    "    S = np.concatenate([P, Outs], axis=0)\n",
    "\n",
    "    # select the perturbation -> outcome block of correlation matrix (upper right)\n",
    "    R = np.corrcoef(S)[:n_neurons, n_neurons:]\n",
    "\n",
    "    return R\n",
    "\n",
    "\n",
    "def simulate_neurons_perturb(A, timesteps):\n",
    "    \"\"\"\n",
    "    Simulates a dynamical system for the specified number of neurons and timesteps,\n",
    "    BUT every other timestep the activity is clamped to a random pattern of 1s and 0s\n",
    "\n",
    "    Args:\n",
    "        A (np.array): the true connectivity matrix\n",
    "        timesteps (int): the number of timesteps to simulate our system.\n",
    "\n",
    "    Returns:\n",
    "        The results of the simulated system.\n",
    "        - X has shape (n_neurons, timeteps)\n",
    "    \"\"\"\n",
    "    n_neurons = len(A)\n",
    "    X = np.zeros((n_neurons, timesteps))\n",
    "\n",
    "    for t in range(timesteps - 1):\n",
    "\n",
    "        if t % 2 == 0:\n",
    "            X[:, t] = np.random.choice([0, 1], size=n_neurons)\n",
    "\n",
    "        epsilon = np.random.multivariate_normal(np.zeros(n_neurons), np.eye(n_neurons))\n",
    "        X[:, t + 1] = sigmoid(A.dot(X[:, t]) + epsilon)  # we are using helper function sigmoid\n",
    "\n",
    "    return X\n",
    "\n",
    "\n",
    "def plot_connectivity_matrix(A, ax=None):\n",
    "  \"\"\"Plot the (weighted) connectivity matrix A as a heatmap\n",
    "\n",
    "    Args:\n",
    "      A (ndarray): connectivity matrix (n_neurons by n_neurons)\n",
    "      ax: axis on which to display connectivity matrix\n",
    "  \"\"\"\n",
    "  if ax is None:\n",
    "    ax = plt.gca()\n",
    "  lim = np.abs(A).max()\n",
    "  im = ax.imshow(A, vmin=-lim, vmax=lim, cmap=\"coolwarm\")\n",
    "  ax.tick_params(labelsize=10)\n",
    "  ax.xaxis.label.set_size(15)\n",
    "  ax.yaxis.label.set_size(15)\n",
    "  cbar = ax.figure.colorbar(im, ax=ax, ticks=[0], shrink=.7)\n",
    "  cbar.ax.set_ylabel(\"Connectivity Strength\", rotation=90,\n",
    "                     labelpad= 20,va=\"bottom\")\n",
    "  ax.set(xlabel=\"Connectivity from\", ylabel=\"Connectivity to\")\n",
    "\n",
    "\n",
    "def plot_connectivity_graph_matrix(A):\n",
    "  \"\"\"Plot both connectivity graph and matrix side by side\n",
    "\n",
    "    Args:\n",
    "      A (ndarray): connectivity matrix (n_neurons by n_neurons)\n",
    "\n",
    "  \"\"\"\n",
    "  fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
    "  see_neurons(A,axs[0])  # we are invoking a helper function that visualizes the connectivity matrix\n",
    "  plot_connectivity_matrix(A)\n",
    "\n",
    "  fig.suptitle(\"Neuron Connectivity\")\n",
    "  plt.show()\n",
    "\n",
    "def plot_neural_activity(X):\n",
    "  \"\"\"Plot first 10 timesteps of neural activity\n",
    "\n",
    "  Args:\n",
    "    X (ndarray): neural activity (n_neurons by timesteps)\n",
    "\n",
    "  \"\"\"\n",
    "  f, ax = plt.subplots()\n",
    "  im = ax.imshow(X[:, :10])\n",
    "  divider = make_axes_locatable(ax)\n",
    "  cax1 = divider.append_axes(\"right\", size=\"5%\", pad=0.15)\n",
    "  plt.colorbar(im, cax=cax1)\n",
    "  ax.set(xlabel='Timestep', ylabel='Neuron', title='Simulated Neural Activity')\n",
    "\n",
    "\n",
    "def plot_true_vs_estimated_connectivity(estimated_connectivity, true_connectivity, selected_neuron=None):\n",
    "  \"\"\"Visualize true vs estimated connectivity matrices\n",
    "\n",
    "  Args:\n",
    "    estimated_connectivity (ndarray): estimated connectivity (n_neurons by n_neurons)\n",
    "    true_connectivity (ndarray): ground-truth connectivity (n_neurons by n_neurons)\n",
    "    selected_neuron (int or None): None if plotting all connectivity, otherwise connectivity\n",
    "      from selected_neuron will be shown\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
    "\n",
    "  if selected_neuron is not None:\n",
    "    plot_connectivity_matrix(np.expand_dims(estimated_connectivity, axis=1), ax=axs[0])\n",
    "    plot_connectivity_matrix(true_connectivity[:, [selected_neuron]], ax=axs[1])\n",
    "    axs[0].set_xticks([0])\n",
    "    axs[1].set_xticks([0])\n",
    "    axs[0].set_xticklabels([selected_neuron])\n",
    "    axs[1].set_xticklabels([selected_neuron])\n",
    "  else:\n",
    "    plot_connectivity_matrix(estimated_connectivity, ax=axs[0])\n",
    "    plot_connectivity_matrix(true_connectivity, ax=axs[1])\n",
    "\n",
    "  axs[1].set(title=\"True connectivity\")\n",
    "  axs[0].set(title=\"Estimated connectivity\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 1: Defining and estimating causality\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 537
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:27.618548Z",
     "iopub.status.busy": "2021-05-25T01:22:27.617580Z",
     "iopub.status.idle": "2021-05-25T01:22:27.653101Z",
     "shell.execute_reply": "2021-05-25T01:22:27.652612Z"
    },
    "outputId": "3d6a446c-7595-4022-e20c-b822d815cefe"
   },
   "outputs": [],
   "source": [
    "#@title Video 1: Defining causality\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"yiddT2sMbZM\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "\n",
    "Let's think carefully about the statement \"**A causes B**\". To be concrete, let's take two neurons. What does it mean to say that neuron $A$ causes neuron $B$ to fire?\n",
    "\n",
    "The *interventional* definition of causality says that:\n",
    "$$\n",
    "(A \\text{ causes } B) \\Leftrightarrow ( \\text{ If we force }A \\text { to be different, then }B\\text{ changes})\n",
    "$$\n",
    "\n",
    "To determine if $A$ causes $B$ to fire, we can inject current into neuron $A$ and see what happens to $B$.\n",
    "\n",
    "**A mathematical definition of causality**: \n",
    "Over many trials, the average causal effect $\\delta_{A\\to B}$ of neuron $A$ upon neuron $B$ is the average change in neuron $B$'s activity when we set $A=1$ versus when we set $A=0$.\n",
    "\n",
    "\n",
    "$$\n",
    "\\delta_{A\\to B} = \\mathbb{E}[B | A=1] -  \\mathbb{E}[B | A=0] \n",
    "$$\n",
    "\n",
    "Note that this is an average effect. While one can get more sophisticated about conditional effects ($A$ only effects $B$ when it's not refractory, perhaps), we will only consider average effects today.\n",
    "\n",
    "**Relation to a randomized controlled trial (RCT)**:\n",
    "The logic we just described is the logic of a randomized control trial (RCT). If you randomly give 100 people a drug and 100 people a placebo, the effect is the difference in outcomes.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Exercise 1: Randomized controlled trial for two neurons\n",
    "\n",
    "Let's pretend we can perform a randomized controlled trial for two neurons. Our model will have neuron $A$ synapsing on Neuron $B$:\n",
    "$$B = A + \\epsilon$$\n",
    " where $A$ and $B$ represent the activities of the two neurons and $\\epsilon$ is standard normal noise $\\epsilon\\sim\\mathcal{N}(0,1)$.\n",
    "\n",
    "Our goal is to perturb $A$ and confirm that $B$ changes. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:27.658284Z",
     "iopub.status.busy": "2021-05-25T01:22:27.657724Z",
     "iopub.status.idle": "2021-05-25T01:22:27.661539Z",
     "shell.execute_reply": "2021-05-25T01:22:27.661047Z"
    }
   },
   "outputs": [],
   "source": [
    "def neuron_B(activity_of_A):\n",
    "  \"\"\"Model activity of neuron B as neuron A activity + noise\n",
    "\n",
    "  Args:\n",
    "    activity_of_A (ndarray): An array of shape (T,) containing the neural activity of neuron A\n",
    "\n",
    "  Returns:\n",
    "    ndarray: activity of neuron B\n",
    "  \"\"\"\n",
    "  noise = np.random.randn(activity_of_A.shape[0])\n",
    "  return activity_of_A + noise\n",
    "\n",
    "np.random.seed(12)\n",
    "\n",
    "# Neuron A activity of zeros\n",
    "A_0 = np.zeros(5000)\n",
    "\n",
    "# Neuron A activity of ones\n",
    "A_1 = np.ones(5000)\n",
    "\n",
    "###########################################################################\n",
    "## TODO for students: Estimate the causal effect of A upon B\n",
    "## Use eq above (difference in mean of B when A=0 vs. A=1)\n",
    "###########################################################################\n",
    "diff_in_means = ...\n",
    "#print(diff_in_means)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cellView": "both",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:27.666548Z",
     "iopub.status.busy": "2021-05-25T01:22:27.666037Z",
     "iopub.status.idle": "2021-05-25T01:22:27.671800Z",
     "shell.execute_reply": "2021-05-25T01:22:27.671077Z"
    },
    "outputId": "8caa500d-925b-4cc4-bfb3-e8d3b2e133cd"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W3D5_NetworkCausality/solutions/W3D5_Tutorial1_Solution_9ae3afbe.py)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "You should get a difference in means of `0.990719` (so very close to one). "
   ]
  },
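  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Why is the answer so close to one? Here is a short derivation, offered as a sanity check rather than as part of the official solution. It assumes the two batches of 5000 noise samples are independent:\n",
    "\n",
    "$$\n",
    "\\mathbb{E}[B | A=a] = a + \\mathbb{E}[\\epsilon] = a \\quad\\Rightarrow\\quad \\delta_{A\\to B} = 1 - 0 = 1\n",
    "$$\n",
    "\n",
    "Each sample mean averages $T = 5000$ draws of unit-variance noise, so the standard error of the difference in means is\n",
    "\n",
    "$$\n",
    "\\sqrt{\\frac{1}{T} + \\frac{1}{T}} = \\sqrt{\\frac{2}{5000}} = 0.02\n",
    "$$\n",
    "\n",
    "The estimate `0.990719` misses the true effect by about $0.009$, less than half of one standard error, which is what we would expect from sampling noise alone."
   ]
  },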
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 2: Simulating a system of neurons\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Can we still estimate causal effects when the neurons are in big networks? This is the main question we will ask today. Let's first create our system, and the rest of today will be spend analyzing it.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 537
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:27.676850Z",
     "iopub.status.busy": "2021-05-25T01:22:27.676309Z",
     "iopub.status.idle": "2021-05-25T01:22:27.719748Z",
     "shell.execute_reply": "2021-05-25T01:22:27.718802Z"
    },
    "outputId": "2c8104ee-e9c5-449e-c0e0-0e56c1700105"
   },
   "outputs": [],
   "source": [
    "#@title Video 2: Simulated neural system model\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"oPJz49dAuL8\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Video correction**: the connectivity graph plots and associated explanations in this and other videos show the wrong direction of connectivity (the arrows should be pointing the opposite direction). This has been fixed in the figures below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 2.1: Our system\n",
    "\n",
    "This section recaps the system described in Video 2 so may be skipped. \n",
    "\n",
    "Our system has N interconnected neurons that affect each other over time. Each neuron at time $t+1$ is a function of the activity of the other neurons from the previous time $t$. \n",
    "\n",
    "Neurons affect each other nonlinearly: each neuron's activity at time $t+1$ consists of a linearly weighted sum of all neural activities at time $t$, with added noise, passed through a nonlinearity:\n",
    "\n",
    "$$\n",
    "\\vec{x}_{t+1} = \\sigma(A\\vec{x}_t + \\epsilon_t), \n",
    "$$\n",
    "\n",
    "- $\\vec{x}_t$ is an $n$-dimensional vector representing our $n$-neuron system at timestep $t$\n",
    "- $\\sigma$ is a sigmoid nonlinearity\n",
    "- $A$ is our $n \\times n$ *causal ground truth connectivity matrix* (more on this later)\n",
    "- $\\epsilon_t$ is random noise: $\\epsilon_t \\sim N(\\vec{0}, I_n)$\n",
    "- $\\vec{x}_0$ is initialized to $\\vec{0}$\n",
    "\n",
    "$A$ is a connectivity matrix, so the element $A_{ij}$ represents the causal effect of neuron $i$ on neuron $j$. In our system, neurons will receive connections from only 10% of the whole population on average.\n"
   ]
  },
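  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the update rule concrete, here is a small worked example. The 2-neuron matrix below is purely illustrative (it is not the matrix used in this tutorial), and the noise is set to zero for readability. Written out for a single neuron $j$, the update is\n",
    "\n",
    "$$\n",
    "x^j_{t+1} = \\sigma\\Big(\\sum_{k=1}^{n} A_{jk}\\, x^k_t + \\epsilon^j_t\\Big)\n",
    "$$\n",
    "\n",
    "so row $j$ of $A$ collects the inputs to neuron $j$. With\n",
    "\n",
    "$$\n",
    "A = \\begin{pmatrix} 0 & 0.5 \\\\ 0 & 0 \\end{pmatrix}, \\quad \\vec{x}_t = \\begin{pmatrix} 0 \\\\ 1 \\end{pmatrix}, \\quad \\epsilon_t = \\vec{0}\n",
    "$$\n",
    "\n",
    "we get\n",
    "\n",
    "$$\n",
    "\\vec{x}_{t+1} = \\begin{pmatrix} \\sigma(0.5) \\\\ \\sigma(0) \\end{pmatrix} \\approx \\begin{pmatrix} 0.62 \\\\ 0.5 \\end{pmatrix}\n",
    "$$\n",
    "\n",
    "The single nonzero entry $A_{12}$ lets the active neuron 2 push neuron 1 above its baseline of $\\sigma(0) = 0.5$, while neuron 2 receives no input and stays at baseline."
   ]
  },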
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 2.2: Visualize  true connectivity\n",
    "\n",
    "We will create a connectivity matrix between 6 neurons and visualize it in two different ways: as a graph with directional edges between connected neurons and as an image of the connectivity matrix.\n",
    "\n",
    "*Check your understanding*: do you understand how the left plot relates to the right plot below?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 364
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:27.730204Z",
     "iopub.status.busy": "2021-05-25T01:22:27.723221Z",
     "iopub.status.idle": "2021-05-25T01:22:28.101383Z",
     "shell.execute_reply": "2021-05-25T01:22:28.100895Z"
    },
    "outputId": "bc378d46-4d05-46bf-f509-856685adc052"
   },
   "outputs": [],
   "source": [
    "#@markdown Execute this cell to visualize connectivity\n",
    "\n",
    "## Initializes the system\n",
    "n_neurons = 6\n",
    "A = create_connectivity(n_neurons) # we are invoking a helper function that generates our nxn causal connectivity matrix.\n",
    "\n",
    "# Let's plot it!\n",
    "plot_connectivity_graph_matrix(A)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Exercise 2: System simulation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "In this exercise we're going to simulate the system. Please complete the following function so that at every timestep the activity vector $x$ is updated according to:\n",
    "$$\n",
    "\\vec{x}_{t+1} = \\sigma(A\\vec{x}_t + \\epsilon_t).\n",
    "$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:28.107755Z",
     "iopub.status.busy": "2021-05-25T01:22:28.107148Z",
     "iopub.status.idle": "2021-05-25T01:22:28.110743Z",
     "shell.execute_reply": "2021-05-25T01:22:28.110260Z"
    }
   },
   "outputs": [],
   "source": [
    "def simulate_neurons(A, timesteps, random_state=42):\n",
    "    \"\"\"Simulates a dynamical system for the specified number of neurons and timesteps.\n",
    "\n",
    "    Args:\n",
    "        A (np.array): the connectivity matrix\n",
    "        timesteps (int): the number of timesteps to simulate our system.\n",
    "        random_state (int): random seed for reproducibility\n",
    "\n",
    "    Returns:\n",
    "        - X has shape (n_neurons, timeteps). A schematic:\n",
    "                   ___t____t+1___\n",
    "       neuron  0  |   0    1     |\n",
    "                  |   1    0     |\n",
    "       neuron  i  |   0 -> 1     |\n",
    "                  |   0    0     |\n",
    "                  |___1____0_____|\n",
    "\n",
    "    \"\"\"\n",
    "    np.random.seed(random_state)\n",
    "\n",
    "    n_neurons = len(A)\n",
    "    X = np.zeros((n_neurons, timesteps))\n",
    "\n",
    "    for t in range(timesteps - 1):\n",
    "\n",
    "        # Create noise vector\n",
    "        epsilon = np.random.multivariate_normal(np.zeros(n_neurons), np.eye(n_neurons))\n",
    "\n",
    "        ########################################################################\n",
    "        ## TODO: Fill in the update rule for our dynamical system.\n",
    "        ## Fill in function and remove\n",
    "        raise NotImplementedError(\"Complete simulate_neurons\")\n",
    "        ########################################################################\n",
    "\n",
    "        # Update activity vector for next step\n",
    "        X[:, t + 1] = sigmoid(...)  # we are using helper function sigmoid\n",
    "\n",
    "    return X\n",
    "\n",
    "\n",
    "# Set simulation length\n",
    "timesteps = 5000\n",
    "\n",
    "# Uncomment below to test your function\n",
    "\n",
    "# Simulate our dynamical system\n",
    "# X = simulate_neurons(A, timesteps)\n",
    "\n",
    "# plot_neural_activity(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 353
    },
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:28.117855Z",
     "iopub.status.busy": "2021-05-25T01:22:28.117270Z",
     "iopub.status.idle": "2021-05-25T01:22:28.999605Z",
     "shell.execute_reply": "2021-05-25T01:22:28.999088Z"
    },
    "outputId": "2481d229-c98a-4fce-a26d-a4c18fa08d74"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W3D5_NetworkCausality/solutions/W3D5_Tutorial1_Solution_b2fb6587.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=557 height=343 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/W3D5_NetworkCausality/static/W3D5_Tutorial1_Solution_b2fb6587_0.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 3: Recovering connectivity through perturbation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 537
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:29.004866Z",
     "iopub.status.busy": "2021-05-25T01:22:29.004315Z",
     "iopub.status.idle": "2021-05-25T01:22:29.038139Z",
     "shell.execute_reply": "2021-05-25T01:22:29.038709Z"
    },
    "outputId": "d971439e-5408-48ad-b933-d28e55d86332"
   },
   "outputs": [],
   "source": [
    "#@title Video 3: Perturbing systems\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"wOZunGtuqQE\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 3.1: Random perturbation in our system of neurons\n",
    "\n",
    "We want to get the causal effect of each neuron upon each other neuron. The ground truth of the causal effects is the connectivity matrix $A$.\n",
    "\n",
    "Remember that we would like to calculate:\n",
    "$$\n",
    "\\delta_{A\\to B} = \\mathbb{E}[B | A=1] -  \\mathbb{E}[B | A=0] \n",
    "$$\n",
    "\n",
    "\n",
    "We'll do this by randomly setting the system state to 0 or 1 and observing the outcome after one timestep. If we do this $N$ times, the effect of neuron $i$ upon neuron $j$ is:\n",
    "$$\n",
    "\\delta_{x^i\\to x^j} \\approx \\frac1N \\sum_i^N[x_{t+1}^j | x^i_t=1] -  \\frac1N \\sum_i^N[x_{t+1}^j | x^i_t=0]\n",
    "$$\n",
    "This is just the average difference of the activity of neuron $j$ in the two conditions.\n",
    "\n",
    "We are going to calculate the above equation, but imagine it like *intervening* in activity every other timestep."
   ]
  },
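  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "What does this difference in means recover in our nonlinear system? The following is a rough sketch rather than a formal result. Let $w \\geq 0$ denote the entry of $A$ that carries the connection from neuron $i$ to neuron $j$, and let $z$ collect everything else that drives neuron $j$ at that timestep (input from the other neurons plus noise). Because every neuron is set to 0 or 1 independently, $z$ has the same distribution whether $x^i_t$ is 0 or 1, so\n",
    "\n",
    "$$\n",
    "\\delta_{x^i\\to x^j} = \\mathbb{E}[\\sigma(w + z)] - \\mathbb{E}[\\sigma(z)] \\approx w\\, \\mathbb{E}[\\sigma'(z)] \\leq \\frac{w}{4}\n",
    "$$\n",
    "\n",
    "The approximation is a first-order Taylor expansion in $w$, and the bound uses $\\sigma'(z) \\leq 1/4$. Under these assumptions the estimated effect is zero when there is no connection and grows with $w$, but it is a scaled-down version of the true weight rather than the weight itself."
   ]
  },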
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We will use helper function `simulate_neurons_perturb`. While the rest of the function is the same as the ``simulate_neurons`` function in the previous exercise, every time step we now additionally include:\n",
    "```\n",
    "if t % 2 == 0:\n",
    "    X[:,t] = np.random.choice([0,1], size=n_neurons)\n",
    "```\n",
    "\n",
    "This means that at every other timestep,  every neuron's activity is changed to either 0 or 1. \n",
    "\n",
    "Pretty serious perturbation, huh? You don't want that going on in your brain.\n",
    "\n",
    "**Now visually compare the dynamics:** Run this next cell and see if you can spot how the dynamics have changed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 286
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:29.046387Z",
     "iopub.status.busy": "2021-05-25T01:22:29.045809Z",
     "iopub.status.idle": "2021-05-25T01:22:30.134508Z",
     "shell.execute_reply": "2021-05-25T01:22:30.134948Z"
    },
    "outputId": "e4fadc39-d058-45d7-a285-42027805bcad"
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to visualize perturbed dynamics\n",
    "\n",
    "timesteps = 5000  # Simulate for 5000 timesteps.\n",
    "\n",
    "# Simulate our dynamical system for the given amount of time\n",
    "X_perturbed = simulate_neurons_perturb(A, timesteps)\n",
    "\n",
    "# Plot our standard versus perturbed dynamics\n",
    "fig, axs = plt.subplots(1, 2, figsize=(15, 4))\n",
    "im0 = axs[0].imshow(X[:, :10])\n",
    "im1 = axs[1].imshow(X_perturbed[:, :10])\n",
    "\n",
    "# Matplotlib boilerplate code\n",
    "divider = make_axes_locatable(axs[0])\n",
    "cax0 = divider.append_axes(\"right\", size=\"5%\", pad=0.15)\n",
    "plt.colorbar(im0, cax=cax0)\n",
    "\n",
    "divider = make_axes_locatable(axs[1])\n",
    "cax1 = divider.append_axes(\"right\", size=\"5%\", pad=0.15)\n",
    "plt.colorbar(im1, cax=cax1)\n",
    "\n",
    "axs[0].set_ylabel(\"Neuron\", fontsize=15)\n",
    "axs[1].set_xlabel(\"Timestep\", fontsize=15)\n",
    "axs[0].set_xlabel(\"Timestep\", fontsize=15);\n",
    "axs[0].set_title(\"Standard dynamics\", fontsize=15)\n",
    "axs[1].set_title(\"Perturbed dynamics\", fontsize=15);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 537
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:30.141041Z",
     "iopub.status.busy": "2021-05-25T01:22:30.140486Z",
     "iopub.status.idle": "2021-05-25T01:22:30.177169Z",
     "shell.execute_reply": "2021-05-25T01:22:30.176055Z"
    },
    "outputId": "52e169ed-0420-4001-f11b-b6a5e0d7f486"
   },
   "outputs": [],
   "source": [
    "#@title Video 4: Calculating causality\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"EDZtcsIAVGM\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Exercise 3: Using perturbed dynamics to recover connectivity\n",
    "\n",
    "From the above perturbed dynamics, write a function that recovers the causal effect of a given single neuron (`selected_neuron`) upon all other neurons in the system. Remember from above you're calculating:\n",
    "$$\n",
    "\\delta_{x^i\\to x^j} \\approx \\frac1N \\sum_i^N[x_{t+1}^j | x^i_t=1] -  \\frac1N \\sum_i^N[x_{t+1}^j | x^i_t=0]] \n",
    "$$\n",
    "\n",
    "\n",
    "Recall that we perturbed every neuron at every other timestep. Despite perturbing every neuron, in this exercise we are concentrating on computing the causal effect of a single neuron (we will look at all neurons effects on all neurons next). We want to exclusively use the timesteps without perturbation for $x^j_{t+1}$ and the timesteps with perturbation for $x^j_{t}$ in the formulas above. In numpy, indexing occurs as `array[ start_index : end_index : count_by]`. So getting every other element in an array (such as every other timestep) is as easy as `array[::2]`."
   ]
  },
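  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a minimal sketch of the two ingredients described above (strided slicing and a difference in means), here is a toy example. The array `toy_X` and the names `perturbation`, `outcome`, and `toy_effect` are hypothetical and chosen only for illustration; they are not part of the notebook's simulated system.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy recording: 2 neurons (rows) x 8 timesteps (columns).\n",
    "# Even columns are perturbation steps, odd columns are the steps that follow.\n",
    "toy_X = np.array([[1.0, 0.3, 0.0, 0.6, 1.0, 0.1, 0.0, 0.8],\n",
    "                  [0.0, 0.9, 1.0, 0.2, 1.0, 0.7, 0.0, 0.4]])\n",
    "\n",
    "# Perturbed state of neuron 0 at t = 0, 2, 4, 6\n",
    "perturbation = toy_X[0, ::2]\n",
    "\n",
    "# Activity of neuron 1 at t = 1, 3, 5, 7\n",
    "outcome = toy_X[1, 1::2]\n",
    "\n",
    "# Mean outcome after a 1 minus mean outcome after a 0\n",
    "toy_effect = outcome[perturbation == 1].mean() - outcome[perturbation == 0].mean()\n",
    "print(toy_effect)\n",
    "```\n",
    "\n",
    "The sketch uses boolean masks to pick the timesteps; index arrays from `np.argwhere` would select the same elements."
   ]
  },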
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:30.184814Z",
     "iopub.status.busy": "2021-05-25T01:22:30.183210Z",
     "iopub.status.idle": "2021-05-25T01:22:30.810998Z",
     "shell.execute_reply": "2021-05-25T01:22:30.810447Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_perturbed_connectivity_from_single_neuron(perturbed_X, selected_neuron):\n",
    "    \"\"\"\n",
    "    Computes the connectivity matrix from the selected neuron using differences in means.\n",
    "\n",
    "    Args:\n",
    "        perturbed_X (np.ndarray): the perturbed dynamical system matrix of shape (n_neurons, timesteps)\n",
    "        selected_neuron (int): the index of the neuron we want to estimate connectivity for\n",
    "\n",
    "    Returns:\n",
    "        estimated_connectivity (np.ndarray): estimated connectivity for the selected neuron, of shape (n_neurons,)\n",
    "    \"\"\"\n",
    "    # Extract the perturbations of neuron 1 (every other timestep)\n",
    "    neuron_perturbations = perturbed_X[selected_neuron, ::2]\n",
    "\n",
    "    # Extract the observed outcomes of all the neurons (every other timestep)\n",
    "    all_neuron_output = perturbed_X[:, 1::2]\n",
    "\n",
    "    # Initialize estimated connectivity matrix\n",
    "    estimated_connectivity = np.zeros(n_neurons)\n",
    "\n",
    "    # Loop over neurons\n",
    "    for neuron_idx in range(n_neurons):\n",
    "\n",
    "        # Get this output neurons (neuron_idx) activity\n",
    "        this_neuron_output = all_neuron_output[neuron_idx, :]\n",
    "\n",
    "        # Get timesteps where the selected neuron == 0 vs == 1\n",
    "        one_idx = np.argwhere(neuron_perturbations == 1)\n",
    "        zero_idx = np.argwhere(neuron_perturbations == 0)\n",
    "\n",
    "        ########################################################################\n",
    "        ## TODO: Insert your code here to compute the neuron activation from perturbations.\n",
    "        # Fill out function and remove\n",
    "        raise NotImplementedError(\"Complete the function get_perturbed_connectivity_single_neuron\")\n",
    "        ########################################################################\n",
    "\n",
    "        difference_in_means = ...\n",
    "\n",
    "        estimated_connectivity[neuron_idx] = difference_in_means\n",
    "\n",
    "    return estimated_connectivity\n",
    "\n",
    "\n",
    "# Initialize the system\n",
    "n_neurons = 6\n",
    "timesteps = 5000\n",
    "selected_neuron = 1\n",
    "\n",
    "# Simulate our perturbed dynamical system\n",
    "perturbed_X = simulate_neurons_perturb(A, timesteps)\n",
    "\n",
    "\n",
    "## Uncomment below to test your function\n",
    "\n",
    "# Measure connectivity of neuron 1\n",
    "# estimated_connectivity = get_perturbed_connectivity_from_single_neuron(perturbed_X, selected_neuron)\n",
    "\n",
    "# plot_true_vs_estimated_connectivity(estimated_connectivity, A, selected_neuron)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 357
    },
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:30.819788Z",
     "iopub.status.busy": "2021-05-25T01:22:30.818951Z",
     "iopub.status.idle": "2021-05-25T01:22:31.850011Z",
     "shell.execute_reply": "2021-05-25T01:22:31.850450Z"
    },
    "outputId": "87b590c6-662b-4564-b5ae-67362249bb44"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W3D5_NetworkCausality/solutions/W3D5_Tutorial1_Solution_b51df5f6.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=486 height=341 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/W3D5_NetworkCausality/static/W3D5_Tutorial1_Solution_b51df5f6_0.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We can quantify how close our estimated connectivity matrix is to our true connectivity matrix by correlating them. We should see almost perfect correlation between our estimates and the true connectivity - do we?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:31.855572Z",
     "iopub.status.busy": "2021-05-25T01:22:31.854662Z",
     "iopub.status.idle": "2021-05-25T01:22:31.859372Z",
     "shell.execute_reply": "2021-05-25T01:22:31.858855Z"
    },
    "outputId": "cf9e22dc-1612-4117-d6d2-b499503e3c6e"
   },
   "outputs": [],
   "source": [
    "# Correlate true vs estimated connectivity matrix\n",
    "np.corrcoef(A[:, selected_neuron], estimated_connectivity)[1, 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Note on interpreting A**: Strictly speaking, $A$ is not the matrix of causal effects but rather the dynamics matrix. So why compare them like this? The answer is that $A$ and the effect matrix both are $0$ everywhere except where there is a directed connection. So they should have a correlation of $1$ if we estimate the effects correctly. (Their scales, however, are different. This in part because the nonlinearity $\\sigma$ squashes the values of $x$ to $[0,1]$.) See the Appendix after Tutorial 2 for more discussion of using correlation as a metric."
   ]
  },
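  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To illustrate why a difference in scale alone does not lower the correlation, here is a small sketch with made-up numbers. The vectors `true_weights` and `scaled_effects` are hypothetical and are not outputs of this notebook; the sketch only shows that the Pearson correlation is unchanged when one vector is rescaled by a positive constant.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical column of a dynamics matrix: zero except at two connections\n",
    "true_weights = np.array([0.0, 0.8, 0.0, -0.5, 0.0, 0.0])\n",
    "\n",
    "# Hypothetical effect estimates: same pattern on a smaller scale\n",
    "scaled_effects = 0.2 * true_weights\n",
    "\n",
    "# Pearson correlation between the two vectors\n",
    "r = np.corrcoef(true_weights, scaled_effects)[1, 0]\n",
    "print(r)\n",
    "```\n",
    "\n",
    "In the simulated system the squashing by $\\sigma$ is nonlinear and the estimates are noisy, so the measured correlation can fall somewhat below $1$."
   ]
  },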
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 3.2: Measuring how perturbations recover the entire connectivity matrix\n",
    "\n",
    "Nice job! You just estimated connectivity for a single neuron.\n",
    "\n",
    "We're now going to use the same strategy for all neurons at once. We provide this helper function `get_perturbed_connectivity_all_neurons`.  If you're curious about how this works and have extra time, scroll to the explanation at the bottom.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:31.863622Z",
     "iopub.status.busy": "2021-05-25T01:22:31.863049Z",
     "iopub.status.idle": "2021-05-25T01:22:32.510576Z",
     "shell.execute_reply": "2021-05-25T01:22:32.510055Z"
    }
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "n_neurons = 6\n",
    "timesteps = 5000\n",
    "\n",
    "# Generate nxn causal connectivity matrix\n",
    "A = create_connectivity(n_neurons)\n",
    "\n",
    "# Simulate perturbed dynamical system\n",
    "perturbed_X = simulate_neurons_perturb(A, timesteps)\n",
    "\n",
    "# Get estimated connectivity matrix\n",
    "R = get_perturbed_connectivity_all_neurons(perturbed_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 715
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:32.532505Z",
     "iopub.status.busy": "2021-05-25T01:22:32.531915Z",
     "iopub.status.idle": "2021-05-25T01:22:33.289447Z",
     "shell.execute_reply": "2021-05-25T01:22:33.288973Z"
    },
    "outputId": "d17d6429-7e68-4493-b140-c83ec76e6116"
   },
   "outputs": [],
   "source": [
    "#@markdown Execute this cell to visualize true vs estimated connectivity\n",
    "\n",
    "# Let's visualize the true connectivity and estimated connectivity together\n",
    "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
    "see_neurons(A, axs[0]) # we are invoking a helper function that visualizes the connectivity matrix\n",
    "plot_connectivity_matrix(A, ax=axs[1])\n",
    "plt.suptitle(\"True connectivity matrix A\");\n",
    "plt.show()\n",
    "fig, axs = plt.subplots(1,2, figsize=(10,5))\n",
    "see_neurons(R.T,axs[0]) # we are invoking a helper function that visualizes the connectivity matrix\n",
    "plot_connectivity_matrix(R.T, ax=axs[1])\n",
    "plt.suptitle(\"Estimated connectivity matrix R\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We can again calculate the correlation coefficient between the elements of the two matrices. As you can see from the cell below, we do a good job recovering the true causality of the system!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:33.295245Z",
     "iopub.status.busy": "2021-05-25T01:22:33.294633Z",
     "iopub.status.idle": "2021-05-25T01:22:33.300388Z",
     "shell.execute_reply": "2021-05-25T01:22:33.299816Z"
    },
    "outputId": "2f4b954f-5257-420a-dac2-81270715d7b3"
   },
   "outputs": [],
   "source": [
    "np.corrcoef(A.transpose().flatten(), R.flatten())[1, 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 537
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:22:33.306200Z",
     "iopub.status.busy": "2021-05-25T01:22:33.305623Z",
     "iopub.status.idle": "2021-05-25T01:22:33.341672Z",
     "shell.execute_reply": "2021-05-25T01:22:33.341176Z"
    },
    "outputId": "0e1c5d5f-1a63-4b5e-b965-209abf3191de"
   },
   "outputs": [],
   "source": [
    "#@title Video 5: Summary\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"p3fZW5Woqa4\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "In this tutorial, we learned about how to define and estimate causality using pertubations. In particular we:\n",
    "\n",
    "1) Learned how to simulate a system of connected neurons\n",
    "\n",
    "2) Learned how to estimate the connectivity between neurons by directly perturbing neural activity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Further resources for today\n",
    "\n",
    "If you are interested in causality after NMA ends, here are some useful texts to consult.\n",
    "\n",
    "\n",
    "*   *Causal Inference for Statistics, Social, and Biomedical Sciences* by Imbens and Rubin\n",
    "*   *Causal Inference: What If* by Hernan and Rubin\n",
    "*   *Mostly Harmless Econometrics* by Angrist and Pischke\n",
    "*   https://www.nature.com/articles/s41562-018-0466-5 for application to neuroscience\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Appendix\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Computation of the estimated connectivity matrix\n",
    "\n",
    "**This is an explanation of what the code is doing in `get_perturbed_connectivity_all_neurons()`**\n",
    "\n",
    "First, we compute an estimated connectivity matrix $R$. We extract\n",
    "perturbation matrix $P$ and outcomes matrix $O$:\n",
    "\n",
    "$$\n",
    "P = \\begin{bmatrix}\n",
    "\\mid & \\mid & ... & \\mid \\\\ \n",
    "x_0  & x_2  & ... & x_T  \\\\ \n",
    "\\mid & \\mid & ... & \\mid\n",
    "\\end{bmatrix}_{n \\times T/2}\n",
    "$$\n",
    "\n",
    "$$\n",
    "O = \\begin{bmatrix}\n",
    "\\mid & \\mid & ... & \\mid \\\\ \n",
    "x_1  & x_3  & ... & x_{T-1}  \\\\ \n",
    "\\mid & \\mid & ... & \\mid\n",
    "\\end{bmatrix}_{n \\times T/2}\n",
    "$$\n",
    "\n",
    "And calculate the correlation of matrix $S$, which is $P$ and $O$ stacked on each other:\n",
    "\n",
    "$$\n",
    "S = \\begin{bmatrix}\n",
    "P  \\\\ \n",
    "O\n",
    "\\end{bmatrix}_{2n \\times T/2}\n",
    "$$\n",
    "\n",
    "We then extract $R$ as the upper right $n \\times n$ block of $corr(S)$:\n",
    "\n",
    "\n",
    "This is because the upper right block corresponds to the estimated perturbation effect on outcomes for each pair of neurons in our system.\n",
    "\n",
    "This method gives an estimated connectivity matrix that is the proportional to the result you would obtain with differences in means, and differs only in a proportionality constant that depends on the variance of $x$"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W3D5_Tutorial1",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
